import networkx as nx
from tqdm import tqdm

import torch
import numpy as np
from scipy import sparse
from scipy import stats
from scipy.sparse import coo_matrix, csr_matrix, vstack, csc_matrix, coo_array


def _square_corners(G, base, other, base_nbrs, other_nbrs):
    """Count the 4-cycle corners on the ``base`` side of edge ``(base, other)``.

    A neighbour ``k`` of ``base`` is a corner when it is neither ``other`` nor
    adjacent to ``other``, and it is adjacent to at least one "far" node ``w``
    that neighbours ``other`` but is neither ``base`` nor adjacent to ``base``
    (so ``base - k - w - other`` closes a 4-cycle with no diagonal).

    Returns:
        ``(corners, most)`` where ``corners`` is the number of such ``k`` and
        ``most`` is the largest number of far nodes reached from a single ``k``.
    """
    far = other_nbrs - base_nbrs - {base}
    corners = 0
    most = 0
    for k in base_nbrs - other_nbrs:
        if k == other:
            continue
        closing = len(far.intersection(G[k]))
        if closing:
            corners += 1
            most = max(most, closing)
    return corners, most


def bfc_edge(G: nx.Graph, node1: int, node2: int) -> float:
    """Return the curvature of the edge ``(node1, node2)`` of ``G``.

    The formula matches the balanced Forman curvature (presumably what
    ``bfc`` abbreviates)::

        2/d1 + 2/d2 - 2
        + 2*|T| / max(d1, d2) + |T| / min(d1, d2)
        + (|S1| + |S2|) / (gamma * max(d1, d2))

    where ``d1``/``d2`` are the endpoint degrees, ``T`` is the set of common
    neighbours (triangles on the edge), ``S1``/``S2`` are the neighbours of
    each endpoint that lie on a 4-cycle through the edge, and ``gamma`` is
    the largest number of such 4-cycles passing through a single one of those
    neighbours.  The last term is dropped when there are no 4-cycles.

    Everything is computed from neighbour sets, so node labels may be any
    hashable value and no adjacency matrix is built.  Edge weights are
    ignored; the graph is assumed to be simple (no self-loops).

    Args:
        G: Graph providing ``G.degree[node]`` and ``G[node]`` (neighbours).
        node1: One endpoint of the edge.
        node2: The other endpoint of the edge.

    Returns:
        The curvature value; ``0.0`` when either endpoint has degree <= 1.
    """
    d1 = G.degree[node1]
    d2 = G.degree[node2]
    d_max = max(d1, d2)
    d_min = min(d1, d2)
    # Degree-1 endpoints get curvature 0 by convention; the <= also protects
    # the divisions below if the function is called with an isolated node.
    if d_min <= 1:
        return 0.0

    neighbours_1 = set(G[node1])
    neighbours_2 = set(G[node2])
    triangles = neighbours_1 & neighbours_2

    curvature = (2 / d1 + 2 / d2 - 2
                 + 2 * len(triangles) / d_max + len(triangles) / d_min)

    n_squares_1, gamma_1 = _square_corners(
        G, node1, node2, neighbours_1, neighbours_2)
    n_squares_2, gamma_2 = _square_corners(
        G, node2, node1, neighbours_2, neighbours_1)
    if n_squares_1 == 0 or n_squares_2 == 0:
        return curvature

    # Both counts are >= 1 here, so gamma >= 1 and the division is safe.
    gamma = max(gamma_1, gamma_2)
    return curvature + (1 / (gamma * d_max)) * (n_squares_1 + n_squares_2)


def bfc_nx(G: nx.Graph):
    """Collect ``bfc_edge`` values for every edge of ``G`` in a dense matrix.

    A zero-initialised ``float64`` tensor of shape ``(N, N)`` is created,
    with ``N = len(G.nodes)``.  For each pair ``(u, v)`` yielded by
    ``G.edges`` the entry ``[u, v]`` is set to ``bfc_edge(G, u, v)``; all
    other entries stay zero.

    Node labels are used directly as tensor indices, so they are assumed to
    be integers in ``[0, N)`` -- TODO confirm against callers.
    NOTE(review): only the orientation yielded by ``G.edges`` is written;
    the mirrored entry ``[v, u]`` is not filled in here.

    Args:
        G: Graph whose edges are evaluated.

    Returns:
        The ``(N, N)`` ``torch.float64`` tensor described above.
    """
    num_nodes = len(G.nodes)
    curvature = torch.zeros((num_nodes, num_nodes), dtype=torch.float64)

    for edge in G.edges:
        u, v = edge
        curvature[u, v] = bfc_edge(G, u, v)

    return curvature
